|
| 1 | +from typing import Any |
| 2 | + |
1 | 3 | from langchain.output_parsers import ResponseSchema, StructuredOutputParser
|
2 |
| -from langchain.prompts import ChatPromptTemplate |
| 4 | +from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder |
| 5 | +from langchain.schema import HumanMessage |
| 6 | +from langchain.schema.runnable import RunnableLambda |
3 | 7 |
|
4 | 8 | SPAM_INSTRUCTIONS = """
|
5 | 9 | # Role and goal
|
|
120 | 124 | )
|
121 | 125 |
|
122 | 126 |
|
123 |
| -spam_prompt = ChatPromptTemplate( |
| 127 | +spam_prompt_with_human_message = ChatPromptTemplate( |
124 | 128 | (
|
125 | 129 | ("system", SPAM_INSTRUCTIONS),
|
126 |
| - ("human", USER_QUESTION), |
| 130 | + MessagesPlaceholder("human_message"), |
127 | 131 | )
|
128 | 132 | ).partial(format_instructions=spam_parser.get_format_instructions())
|
129 | 133 |
|
|
134 | 138 | ("human", USER_QUESTION),
|
135 | 139 | )
|
136 | 140 | ).partial(format_instructions=topic_parser.get_format_instructions())
|
| 141 | + |
| 142 | + |
| 143 | +def add_human_message(inputs: dict) -> dict: |
| 144 | + """ |
| 145 | + Adds the human message to the inputs dict and returns it. Ensures |
| 146 | + that the human message includes image URL's if they're present. |
| 147 | + """ |
| 148 | + image_urls = inputs.pop("image_urls", None) |
| 149 | + |
| 150 | + content: list[dict[str, Any]] = [dict(type="text", text=USER_QUESTION.format(**inputs))] |
| 151 | + |
| 152 | + if image_urls: |
| 153 | + for image_url in image_urls: |
| 154 | + content.append(dict(type="image_url", image_url=dict(url=image_url))) |
| 155 | + |
| 156 | + inputs.update(human_message=[HumanMessage(content=content)]) |
| 157 | + return inputs |
| 158 | + |
| 159 | + |
| 160 | +spam_prompt = RunnableLambda(add_human_message) | spam_prompt_with_human_message |
0 commit comments